import PIL.Image as Image
import torch
import random
from torch.utils.data import Dataset
import numpy as np
import re
from tqdm import tqdm
import os

class ImageEncodingDataset(Dataset):
    """Dataset yielding one ``(encoding, target)`` pair per selected frame.

    ``encodings[frame]`` must be a mapping ``obj -> encoding``, and
    ``obj_data[frame][obj]`` a sliceable sequence whose first four values form
    the target. What those four values mean is defined by the caller
    (TODO: document). Each ``__getitem__`` call picks one object of the frame
    at random, so repeated reads of the same index may return different pairs.

    Frames are partitioned into train/val by a shuffled permutation drawn from
    ``seed``. Separately constructed ``'train'`` and ``'val'`` instances that
    use the same seed therefore split the *same* shuffle and never overlap.

    Args:
        encodings: Indexable by frame index. Each item maps object keys to
            per-object encodings (anything ``torch.as_tensor`` accepts).
        obj_data: Indexable by frame index. ``obj_data[f][obj][:4]`` is the
            target for object ``obj`` in frame ``f``.
        split: One of:
            - ``'train'``: the first ``train_frac`` of the shuffle.
            - ``'val'``: the remainder.
            - ``'full'``: all selected frames, shuffled.
        frame_range: ``(start, stop)`` half-open range of frame indices to
            use. Defaults to ``(2000, 4000)``, the range previously
            hard-coded. ``None`` uses ``range(len(encodings))``, which assumes
            frames are indexed ``0..len-1``.
        train_frac: Fraction of the selected frames assigned to ``'train'``.
        seed: Seed for the shuffle. Use the same value for the train and val
            instances. ``None`` falls back to NumPy's global RNG (the old
            behaviour), which gives overlapping splits unless the global seed
            is reset before each construction.

    Raises:
        ValueError: If ``split`` is not ``'train'``, ``'val'`` or ``'full'``.
    """

    def __init__(self, encodings, obj_data, split='train', *,
                 frame_range=(2000, 4000), train_frac=0.8, seed=0):
        self.encodings = encodings
        self.split = split
        self.obj_data = obj_data

        if frame_range is None:
            candidates = np.arange(len(encodings))
        else:
            candidates = np.arange(*frame_range)

        # Use a dedicated, seeded generator so every instance built with the
        # same seed sees the same permutation. The old per-instance shuffle
        # of the global RNG leaked frames between the train and val datasets.
        rng = np.random if seed is None else np.random.default_rng(seed)
        all_idxes = rng.permutation(candidates)
        cutoff = int(len(all_idxes) * train_frac)
        train_idxes, test_idxes = all_idxes[:cutoff], all_idxes[cutoff:]

        if split == 'train':
            self.idxes = train_idxes
        elif split == 'val':
            self.idxes = test_idxes
        elif split == 'full':
            self.idxes = all_idxes
        else:
            raise ValueError(
                "Invalid split value. Must be one of 'train', 'val' or 'full'."
            )
        print(" ---- ImageEncodingDataset ---- ")
        print("total timesteps: ", len(encodings))
        print(f"{split} timesteps: ", len(self.idxes))

    def __len__(self):
        return len(self.idxes)

    def __getitem__(self, idx):
        """Return ``(encoding, target)`` float32 tensors for a random object
        in the ``idx``-th selected frame."""
        frame_idx = self.idxes[idx]
        frame_encodings = self.encodings[frame_idx]
        # One object per access, chosen uniformly among the frame's objects.
        obj, encoding = random.choice(list(frame_encodings.items()))
        encoding = torch.as_tensor(encoding, dtype=torch.float32)
        target = torch.as_tensor(
            self.obj_data[frame_idx][obj][:4], dtype=torch.float32
        )
        return encoding, target